{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e928f302be93763f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from tqdm.auto import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "import json\n",
    "from internal.dataparsers.colmap_dataparser import ColmapDataParser\n",
    "import internal.utils.colmap as colmap_utils\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "691bf3dde3296bb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_grad_enabled(False)\n",
    "torch.set_printoptions(precision=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9ca4c6a8ec2a1f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from internal.utils.partitioning_utils import SceneConfig, PartitionableScene"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12111ce0923c5c0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "The dataset \"Rubble\" can be downloaded here: https://storage.cmusatyalab.org/mega-nerf-data/rubble-pixsfm.tgz\n",
    "Convert it to colmap format by: `python utils/meganerf2colmap.py ~/data/Mega-NeRF/rubble-pixsfm`\n",
    "Down sample images by: `python utils/image_downsample.py ~/data/Mega-NeRF/rubble-pixsfm/colmap/images --factor 3`\n",
    "\"\"\"\n",
    "\n",
    "dataset_path = os.path.expanduser(\"~/data/Mega-NeRF/rubble-pixsfm/colmap\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deda9f038e3dfa08",
   "metadata": {},
   "source": [
    "# 1. Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b213e5bcf1b8234",
   "metadata": {},
   "outputs": [],
   "source": [
    "colmap_model = colmap_utils.read_model(os.path.join(dataset_path, \"sparse\"))\n",
    "colmap_model = {\n",
    "    \"cameras\": colmap_model[0],\n",
    "    \"images\": colmap_model[1],\n",
    "    \"points3D\": colmap_model[2],\n",
    "}\n",
    "\n",
    "len(colmap_model[\"cameras\"]), len(colmap_model[\"images\"]), len(colmap_model[\"points3D\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "477bf65f12908096",
   "metadata": {},
   "source": [
    "get camera extrinsics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59000629595448e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "R_list, T_list = [], []\n",
    "image_name_list = []\n",
    "image_idx_to_key = []\n",
    "\n",
    "# walk the images in the order the COLMAP model stores them, so that\n",
    "# position i in every list below refers to the same image\n",
    "for image_key, image in colmap_model[\"images\"].items():\n",
    "    image_idx_to_key.append(image_key)\n",
    "    image_name_list.append(image.name)\n",
    "    R_list.append(torch.tensor(image.qvec2rotmat(), dtype=torch.float))\n",
    "    T_list.append(torch.tensor(image.tvec, dtype=torch.float))\n",
    "\n",
    "R = torch.stack(R_list)\n",
    "T = torch.stack(T_list)\n",
    "\n",
    "# sanity check: the index -> key list follows the key order of the model\n",
    "all_image_keys = list(colmap_model[\"images\"].keys())\n",
    "assert image_idx_to_key[0] == all_image_keys[0]\n",
    "assert image_idx_to_key[-1] == all_image_keys[-1]\n",
    "\n",
    "R.shape, T.shape, len(image_idx_to_key),"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8deeff682b20aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# assemble the 4x4 world-to-camera matrices [R | T; 0 0 0 1], then invert them to get camera-to-world\n",
    "w2c = torch.eye(4, dtype=R.dtype).repeat(R.shape[0], 1, 1)\n",
    "w2c[:, :3, :3] = R\n",
    "w2c[:, :3, 3] = T\n",
    "c2w = torch.linalg.inv(w2c)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86191ef6f063706d",
   "metadata": {},
   "source": [
    "get points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df4c7e78ae4e6713",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_point_index = max(colmap_model[\"points3D\"].keys())\n",
    "max_point_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b54fca39ee807c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dense per-point buffers with one slot per possible point id (0..max_point_index);\n",
    "# slots that are never filled keep the defaults set here\n",
    "n_point_slots = max_point_index + 1\n",
    "point_xyzs = torch.zeros((n_point_slots, 3), dtype=torch.float)\n",
    "point_rgbs = torch.zeros((n_point_slots, 3), dtype=torch.uint8)\n",
    "point_errors = torch.full((n_point_slots,), 255., dtype=torch.float)\n",
    "point_n_images = torch.zeros((n_point_slots,), dtype=torch.int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e93a8d36abb9c468",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scatter every point of the model into the dense buffers, using its id as the row index\n",
    "for point_id, point in tqdm(colmap_model[\"points3D\"].items()):\n",
    "    point_n_images[point_id] = point.image_ids.shape[0]\n",
    "    point_errors[point_id] = torch.from_numpy(point.error)\n",
    "    point_rgbs[point_id] = torch.from_numpy(point.rgb)\n",
    "    point_xyzs[point_id] = torch.from_numpy(point.xyz)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c982e2ee3bc47bb4",
   "metadata": {},
   "source": [
    "reorientation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d29be537fd274b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import viser.transforms as vt\n",
    "\n",
    "auto_reorient = False\n",
    "\n",
    "if not auto_reorient:\n",
    "    # In Mega-NeRF, up direction is the -X\n",
    "    # fixed rotation: -90 degrees around X first, then a small turn around Z\n",
    "    fixed_rotation = vt.SO3.from_z_radians(-np.pi / 8.3).as_matrix() @ vt.SO3.from_x_radians(-np.pi / 2.).as_matrix()\n",
    "    rotation_transform = torch.eye(4)\n",
    "    rotation_transform[:3, :3] = torch.from_numpy(fixed_rotation)\n",
    "    # the third row of the rotation is the direction that ends up on +Z\n",
    "    up = rotation_transform[2, :3]\n",
    "else:\n",
    "    # estimate the up direction of the scene as the negated mean of the cameras' Y axes\n",
    "    # NOTE: \n",
    "    #   the estimate may be imperfect or even wrong for some scenes, \n",
    "    #   in such a situation, you need to provide a correct up vector\n",
    "    up = -c2w[:, :3, 1].mean(dim=0)\n",
    "    up = up / torch.linalg.norm(up)\n",
    "\n",
    "    # rotation taking the estimated up vector onto +Z\n",
    "    rotation_transform = torch.eye(4, dtype=up.dtype)\n",
    "    rotation_transform[:3, :3] = ColmapDataParser.rotation_matrix(up, torch.tensor([0, 0, 1], dtype=up.dtype))\n",
    "\n",
    "    # an extra rotation aligning the camera paths to xy axis\n",
    "    extra_rotation = torch.eye(4)\n",
    "    extra_rotation[:3, :3] = torch.from_numpy(vt.SO3.from_z_radians(-np.pi / 8.5).as_matrix())\n",
    "    rotation_transform = extra_rotation @ rotation_transform\n",
    "\n",
    "up, rotation_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "316ba52fdc03e9c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# only the rotation part is applied; positions are row vectors, hence the transpose\n",
    "reorientation_T = rotation_transform[:3, :3].T\n",
    "reoriented_camera_centers = c2w[:, :3, 3] @ reorientation_T\n",
    "reoriented_point_cloud_xyz = point_xyzs @ reorientation_T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d76fdff4474959ba",
   "metadata": {},
   "source": [
    "extract valid points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62b76a56dc0a43df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a buffer slot holds a real point only if at least one image observes it\n",
    "valid_point_mask = torch.gt(point_n_images, 0)\n",
    "valid_point_rgbs = point_rgbs[valid_point_mask]\n",
    "valid_reoriented_point_xyzs = reoriented_point_cloud_xyz[valid_point_mask]\n",
    "len(valid_reoriented_point_xyzs), len(valid_point_rgbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49cc2c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep only reliable points: seen by at least `min_cameras` images and with an error of at most `max_errors`\n",
    "min_cameras = 3\n",
    "max_errors = 2.\n",
    "\n",
    "seen_by_enough_cameras = point_n_images >= min_cameras\n",
    "has_small_error = point_errors <= max_errors\n",
    "shared_point_mask = seen_by_enough_cameras & has_small_error\n",
    "\n",
    "shared_point_mask.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "949b3dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene = PartitionableScene(camera_centers=reoriented_camera_centers[..., :2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58ccce2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.get_bounding_box_by_points(points=reoriented_point_cloud_xyz[shared_point_mask])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a79ca4e9cf3784dd",
   "metadata": {},
   "source": [
    "Plot the scene to confirm that it is displayed in a top-down view"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d049aabd88e0311",
   "metadata": {},
   "outputs": [],
   "source": [
    "sparsify_points = 16\n",
    "point_bbox = scene.point_based_bounding_box\n",
    "plotted_xyzs = valid_reoriented_point_xyzs[::sparsify_points]\n",
    "plotted_colors = valid_point_rgbs[::sparsify_points] / 255.\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_aspect('equal', adjustable='box')\n",
    "ax.set_xlim(point_bbox.min[0], point_bbox.max[0])\n",
    "ax.set_ylim(point_bbox.min[1], point_bbox.max[1])\n",
    "# every `sparsify_points`-th point in its own color, camera centers in red on top\n",
    "ax.scatter(plotted_xyzs[:, 0], plotted_xyzs[:, 1], c=plotted_colors, s=0.01)\n",
    "ax.scatter(reoriented_camera_centers[:, 0], reoriented_camera_centers[:, 1], s=0.2, c=\"red\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "712443556bd236c9",
   "metadata": {},
   "source": [
    "# 2. Build partitions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9de32d9df7e286fb",
   "metadata": {},
   "source": [
    "choose scene origin and partition size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11c758c560c684bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene_config = SceneConfig(\n",
    "    origin=torch.tensor([10., -56.]),\n",
    "    partition_size=40.,\n",
    ")\n",
    "scene.scene_config = scene_config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ced409320266c45",
   "metadata": {},
   "source": [
    "calculate bounding box and number of partitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e40f69c8cdf7d2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.get_bounding_box_by_camera_centers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5770f7706582061",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.get_scene_bounding_box()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc5b7a3ab989423",
   "metadata": {},
   "source": [
    "plot bounding box"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be500202eb7c7e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.plot(scene.plot_scene_bounding_box)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1aedb6aa364526c3",
   "metadata": {},
   "source": [
    "build partition coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d082afbc1f0e692",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.build_partition_coordinates()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb1a73e3dd9c1378",
   "metadata": {},
   "source": [
    "plot partitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1504312c7a3f0a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.plot(scene.plot_partitions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "531a6eba13f8cb8b",
   "metadata": {},
   "source": [
    "# 3. Assign images to partitions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ffd3ca5b68a54f9",
   "metadata": {},
   "source": [
    "## 3.1. Location based assignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13e18bfed06a0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene_config.location_based_enlarge = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "477e65f3f958f14a",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.camera_center_based_partition_assignment().sum(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b36dc940c0bd23f",
   "metadata": {},
   "source": [
    "## 3.2. Visibility based assignment\n",
    "\n",
    "the visibility is calculated from 3D points of every camera"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633fbba46316ae75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# some parameters may need to be changed\n",
    "scene_config.visibility_based_distance = 1.  # enlarge bounding box by `partition_size * max_visible_distance`, only those cameras inside this enlarged box will be used for visibility based assignment\n",
    "scene_config.visibility_threshold = 1. / 6.\n",
    "# convex hull based visibility\n",
    "scene_config.convex_hull_based_visibility = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a6bf655d48ade76",
   "metadata": {},
   "source": [
    "define image 3D point getter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6792e40f",
   "metadata": {},
   "outputs": [],
   "source": [
    "reliable_points = None\n",
    "if scene_config.convex_hull_based_visibility:\n",
    "    print(\"Convex hull based\")\n",
    "\n",
    "    def project_points(image_idx: int, points, expand: float = 64):\n",
    "        \"\"\"\n",
    "        Project `points` into image `image_idx` using its PINHOLE intrinsics.\n",
    "\n",
    "        Points with a depth of at most 0.01 are dropped, as are points whose pixel\n",
    "        coordinates lie more than `expand` pixels outside the image. The remaining\n",
    "        pixel coordinates are clamped to [0, width] x [0, height] and returned on the CPU.\n",
    "        \"\"\"\n",
    "        colmap_image = colmap_model[\"images\"][image_idx_to_key[image_idx]]  # looked up by the id in colmap sparse model\n",
    "        camera = colmap_model[\"cameras\"][colmap_image.camera_id]\n",
    "\n",
    "        assert camera.model == \"PINHOLE\"\n",
    "\n",
    "        # intrinsics as a 4x4 matrix so it can be chained with the 4x4 world-to-camera matrix\n",
    "        K = torch.eye(4, dtype=points.dtype, device=points.device)\n",
    "        K[0, 0], K[1, 1] = camera.params[0], camera.params[1]\n",
    "        K[0, 2], K[1, 2] = camera.params[2], camera.params[3]\n",
    "\n",
    "        world_to_pixel = K @ w2c[image_idx].to(device=points.device)\n",
    "\n",
    "        # each row is (u * depth, v * depth, depth)\n",
    "        projected = points @ world_to_pixel[:3, :3].T + world_to_pixel[:3, 3]\n",
    "        in_front = projected[projected[:, -1] > 0.01]\n",
    "        points_uv = in_front[:, :2] / in_front[:, -1:]\n",
    "\n",
    "        image_shape = torch.tensor([[camera.width, camera.height]], dtype=torch.float, device=points_uv.device)\n",
    "\n",
    "        # tolerate points up to `expand` pixels outside the image borders\n",
    "        within_lower = points_uv >= 0. - expand\n",
    "        within_upper = points_uv < (image_shape + expand)\n",
    "        valid_uv = points_uv[(within_lower & within_upper).all(dim=-1)]\n",
    "\n",
    "        return torch.min(torch.clamp(valid_uv, min=0.), image_shape).cpu()\n",
    "\n",
    "    # project on the GPU when one is available, otherwise stay on the CPU\n",
    "    # (`project_points` moves its result back to the CPU either way)\n",
    "    reliable_points = point_xyzs[shared_point_mask]\n",
    "    if torch.cuda.is_available():\n",
    "        reliable_points = reliable_points.cuda()\n",
    "\n",
    "    def get_image_points(image_idx: int):\n",
    "        \"\"\"\n",
    "        Return `(2D feature points, their reoriented 3D points, projections of all reliable points)`\n",
    "        for image `image_idx`.\n",
    "        \"\"\"\n",
    "        colmap_image = colmap_model[\"images\"][image_idx_to_key[image_idx]]  # looked up by the id in colmap sparse model\n",
    "\n",
    "        # feature points of this image and the ids of the 3D points they refer to\n",
    "        points_xys = torch.from_numpy(colmap_image.xys)\n",
    "        points_ids = torch.from_numpy(colmap_image.point3D_ids)\n",
    "\n",
    "        # get valid points: only features with a positive 3D point id\n",
    "        has_point = points_ids > 0\n",
    "        points_xys, points_ids = points_xys[has_point], points_ids[has_point]\n",
    "\n",
    "        # filter: only features whose 3D point is selected by `shared_point_mask`\n",
    "        is_reliable = shared_point_mask[points_ids]\n",
    "        points_xys, points_ids = points_xys[is_reliable], points_ids[is_reliable]\n",
    "\n",
    "        # project all points (project non-reoriented points)\n",
    "        projected_points = project_points(image_idx, reliable_points)\n",
    "\n",
    "        return points_xys, reoriented_point_cloud_xyz[points_ids], projected_points\n",
    "else:\n",
    "    print(\"point based\")\n",
    "\n",
    "    def get_image_points(image_idx: int):\n",
    "        \"\"\"Return the reoriented 3D points referenced by image `image_idx` that pass `shared_point_mask`.\"\"\"\n",
    "        observed_ids = torch.from_numpy(colmap_model[\"images\"][image_idx_to_key[image_idx]].point3D_ids)\n",
    "\n",
    "        # get valid points: only positive 3D point ids\n",
    "        observed_ids = observed_ids[observed_ids > 0]\n",
    "        # filter: only ids selected by `shared_point_mask`\n",
    "        observed_ids = observed_ids[shared_point_mask[observed_ids]]\n",
    "\n",
    "        return reoriented_point_cloud_xyz[observed_ids]\n",
    "\n",
    "visibilities_shape = scene.calculate_camera_visibilities(\n",
    "    device=reoriented_point_cloud_xyz.device,\n",
    "    point_getter=get_image_points,\n",
    ").shape\n",
    "print(visibilities_shape)\n",
    "# the point set used for projection is not needed after this call: drop it and release cached CUDA memory\n",
    "del reliable_points\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b49230f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# preview the convex hull\n",
    "\n",
    "# reliable_points = point_xyzs[shared_point_mask]\n",
    "\n",
    "\n",
    "# def plot_convex_hull(image_idx: int):\n",
    "#     from internal.utils.partitioning_utils import Partitioning\n",
    "#     points_2d, points_3d, projected_points = get_image_points(image_idx)\n",
    "#     visibilities, scene_convex_hull, (partition_convex_hull_list, is_in_bounding_boxes) = Partitioning.calculate_convex_hull_based_visibilities(\n",
    "#         scene.partition_coordinates.get_bounding_boxes(enlarge=0.).to(reoriented_point_cloud_xyz.device),\n",
    "#         points_2d=points_2d,\n",
    "#         points_3d=points_3d[:, :2],\n",
    "#         projected_points=projected_points,\n",
    "#     )\n",
    "\n",
    "#     from PIL import Image\n",
    "#     with Image.open(os.path.join(dataset_path, \"images\", colmap_model[\"images\"][image_idx_to_key[image_idx]].name)) as i:\n",
    "#         fig, ax = plt.subplots()\n",
    "#         ax.imshow(i)\n",
    "#         ax.scatter(projected_points[:, 0], projected_points[:, 1], s=0.1)\n",
    "#         ax.plot(projected_points[scene_convex_hull.vertices, 0], projected_points[scene_convex_hull.vertices, 1], 'r--', lw=2)\n",
    "#         ax.plot(projected_points[scene_convex_hull.vertices[0], 0], projected_points[scene_convex_hull.vertices[0], 1], 'ro')\n",
    "#         fig.show()\n",
    "\n",
    "#         for partition_idx, partition_convex_hull in enumerate(partition_convex_hull_list):\n",
    "#             if partition_convex_hull is None:\n",
    "#                 continue\n",
    "\n",
    "#             fig, ax = plt.subplots()\n",
    "#             ax.imshow(i)\n",
    "#             ax.title.set_text(scene.partition_coordinates.get_str_id(partition_idx))\n",
    "#             partition_feature_points = points_2d[is_in_bounding_boxes[partition_idx]]\n",
    "#             ax.scatter(partition_feature_points[:, 0], partition_feature_points[:, 1], s=0.1)\n",
    "\n",
    "#             for simplex in partition_convex_hull.simplices:\n",
    "#                 ax.plot(partition_feature_points[simplex, 0], partition_feature_points[simplex, 1], 'c')\n",
    "#             ax.plot(\n",
    "#                 partition_feature_points[partition_convex_hull.vertices, 0],\n",
    "#                 partition_feature_points[partition_convex_hull.vertices, 1],\n",
    "#                 'o',\n",
    "#                 mec='r',\n",
    "#                 color='none',\n",
    "#                 lw=1,\n",
    "#                 markersize=10,\n",
    "#             )\n",
    "\n",
    "#             fig.show()\n",
    "\n",
    "#     print(\"visibilities={}\".format(visibilities))\n",
    "#     print(scene_convex_hull.volume)\n",
    "\n",
    "\n",
    "# plot_convex_hull(256)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2657ad1c7d8c16d0",
   "metadata": {},
   "source": [
    "assign cameras to partitions based on visibilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6955eec1079b0e7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene_config.visibility_based_distance_enlarge_on_no_location_based = 4.\n",
    "scene_config.visibility_threshold_reduce_on_no_location_based = 4.\n",
    "scene.visibility_based_partition_assignment().sum(dim=-1).reshape(scene.scene_bounding_box.n_partitions[1], scene.scene_bounding_box.n_partitions[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4a8009",
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge small partitions, you can revert merging by calling `scene.unmerge()`\n",
    "scene.merge_no_location_based_partitions()\n",
    "scene.plot(scene.plot_partitions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c9cbb8b3e147daf",
   "metadata": {},
   "source": [
    "# 4. Preview"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7138bcd1e9ec1e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path = os.path.join(dataset_path, scene.build_output_dirname())\n",
    "existing_partitions_file = os.path.join(output_path, \"partitions.pt\")\n",
    "if os.path.exists(existing_partitions_file):\n",
    "    # refuse to reuse a directory that already holds a saved result;\n",
    "    # `output_path` is removed so that later cells fail instead of writing there\n",
    "    del output_path\n",
    "    raise FileExistsError(existing_partitions_file)\n",
    "os.makedirs(output_path, exist_ok=True)\n",
    "output_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed6b811a9221b0c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_plot_points = 51_200\n",
    "plot_point_sparsify = max(valid_reoriented_point_xyzs.shape[0] // max_plot_points, 1)\n",
    "plot_point_sparsify"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4641be5320187594",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one preview image per partition, named after the partition's string id\n",
    "for idx in range(len(scene.partition_coordinates)):\n",
    "    preview_path = os.path.join(output_path, \"{}.png\".format(scene.partition_coordinates.get_str_id(idx)))\n",
    "    scene.save_plot(\n",
    "        scene.plot_partition_assigned_cameras,\n",
    "        preview_path,\n",
    "        idx,\n",
    "        valid_reoriented_point_xyzs,\n",
    "        valid_point_rgbs,\n",
    "        point_sparsify=plot_point_sparsify,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e994b1774ff1754",
   "metadata": {},
   "source": [
    "# 5. Saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4c4126fa8af3c30",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.load(scene.save(\n",
    "    output_path,\n",
    "    extra_data={\n",
    "        \"up\": up,\n",
    "        \"rotation_transform\": rotation_transform,\n",
    "    }\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4b4eb0614d9d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.save_plot(scene.plot_partitions, os.path.join(output_path, \"partitions.png\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e4130f935bb6196",
   "metadata": {},
   "outputs": [],
   "source": [
    "is_images_assigned_to_partitions = torch.logical_or(scene.is_camera_in_partition, scene.is_partitions_visible_to_cameras)\n",
    "is_images_assigned_to_partitions.sum(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "916e287e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the bounding boxes of all partitions (a hard-coded partition index would not exist for every scene configuration)\n",
    "scene.partition_coordinates.get_bounding_boxes()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33f935f6de383e88",
   "metadata": {},
   "source": [
    "write image lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce94d0624b099486",
   "metadata": {},
   "outputs": [],
   "source": [
    "written_idx_list = []\n",
    "for partition_idx in tqdm(range(is_images_assigned_to_partitions.shape[0])):\n",
    "    partition_image_indices = is_images_assigned_to_partitions[partition_idx].nonzero().squeeze(-1).tolist()\n",
    "    if not partition_image_indices:\n",
    "        # no image assigned to this partition: write no files for it\n",
    "        continue\n",
    "\n",
    "    written_idx_list.append(partition_idx)\n",
    "    partition_id_str = scene.partition_coordinates.get_str_id(partition_idx)\n",
    "\n",
    "    # image list: one image name per line\n",
    "    with open(os.path.join(output_path, \"{}.txt\".format(partition_id_str)), \"w\") as f:\n",
    "        for image_index in partition_image_indices:\n",
    "            f.write(image_name_list[image_index] + \"\\n\")\n",
    "\n",
    "    # below camera list is just for visualization, not for training, so its camera intrinsics are fixed values\n",
    "    camera_list = []\n",
    "    for image_index in partition_image_indices:\n",
    "        camera_pose = c2w[image_index]\n",
    "        # red when the partition is flagged as visible to this camera, blue otherwise\n",
    "        is_visibility_based = scene.is_partitions_visible_to_cameras[partition_idx][image_index]\n",
    "        camera_list.append({\n",
    "            \"id\": image_index,\n",
    "            \"img_name\": image_name_list[image_index],\n",
    "            \"width\": 1920,\n",
    "            \"height\": 1080,\n",
    "            \"position\": camera_pose[:3, 3].numpy().tolist(),\n",
    "            \"rotation\": camera_pose[:3, :3].numpy().tolist(),\n",
    "            \"fx\": 1600,\n",
    "            \"fy\": 1600,\n",
    "            \"color\": [255, 0, 0] if is_visibility_based else [0, 0, 255],\n",
    "        })\n",
    "\n",
    "    with open(os.path.join(output_path, f\"cameras-{partition_id_str}.json\"), \"w\") as f:\n",
    "        json.dump(camera_list, f, indent=4, ensure_ascii=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5652916f37e4e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "from internal.utils.graphics_utils import store_ply\n",
    "\n",
    "# thin out large point clouds: keep every `store_point_step`-th valid point (non-reoriented coordinates)\n",
    "max_store_points = 512_000\n",
    "store_point_step = max(valid_point_rgbs.shape[0] // max_store_points, 1)\n",
    "stored_xyzs = point_xyzs[valid_point_mask][::store_point_step]\n",
    "stored_rgbs = valid_point_rgbs[::store_point_step]\n",
    "store_ply(os.path.join(output_path, \"points.ply\"), stored_xyzs, stored_rgbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3f9ca368844d2fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Run below commands to visualize the partitions in web viewer:\\n\")\n",
    "command_template = \"python utils/show_cameras.py \\\\\\n    '{}' \\\\\\n    --points='{}' --up {:.3f} {:.3f} {:.3f} \\n\"\n",
    "points_ply_path = os.path.join(output_path, \"points.ply\")\n",
    "# one command per partition that had files written for it\n",
    "for partition_idx in written_idx_list:\n",
    "    id_str = scene.partition_coordinates.get_str_id(partition_idx)\n",
    "    cameras_json_path = os.path.join(output_path, \"cameras-{}.json\".format(id_str))\n",
    "    print(command_template.format(cameras_json_path, points_ply_path, *(up.tolist())))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
